#include <metal_stdlib>
#include <metal_math>

using namespace metal;

// NOTE(review): not referenced by any code visible in this file. Presumably the
// maximum number of simdgroups per threadgroup assumed by the 32-slot shared
// scratch buffers in the rms_norm kernels — confirm, or remove if unused.
#define NUM_SIMDGROUP 32

// Maps a 2-D thread position to a linear element offset.
// indices.x walks the innermost axis (strides[1]); indices.y the outer axis
// (strides[0]). The offset is computed in size_t and narrowed to uint on return.
METAL_FUNC uint indices_to_idx_2(uint2 indices, constant const size_t strides[2]) {
  const size_t outer = indices.y * strides[0];
  const size_t inner = indices.x * strides[1];
  return outer + inner;
}

// Maps a 3-D thread position to a linear element offset.
// indices.x -> strides[2] (innermost), indices.y -> strides[1],
// indices.z -> strides[0] (outermost). Narrowed to uint on return.
METAL_FUNC uint indices_to_idx_3(uint3 indices, constant const size_t strides[3]) {
  const size_t outer  = indices.z * strides[0];
  const size_t middle = indices.y * strides[1];
  const size_t inner  = indices.x * strides[2];
  return outer + middle + inner;
}

// Maps a 3-D thread position onto a 4-D tensor's linear element offset.
// indices.x -> axis 3, indices.y -> axis 2, and indices.z packs axes 0 and 1:
// axis 1 = z % shape[1], axis 0 = z / shape[1]. Narrowed to uint on return.
METAL_FUNC uint indices_to_idx_4(uint3 indices,
                                 constant const size_t shape[4],
                                 constant const size_t strides[4]) {
  const size_t axis1 = indices.z % shape[1];
  const size_t axis0 = indices.z / shape[1];
  return indices.x * strides[3]
       + indices.y * strides[2]
       + axis1 * strides[1]
       + axis0 * strides[0];
}

template <typename U>
struct MeanOfSquares {
  // Running sums are kept in float regardless of the element type U.
  static constexpr constant float init = 0.0;

  // Accumulates the square of one element (promoted to float first).
  float operator()(float acc, U a) {
    const float v = static_cast<float>(a);
    return v * v + acc;
  }

  // Sums the per-lane partials across the simdgroup, then divides by the
  // length of the reduced axis to turn the sum of squares into a mean.
  float simd_reduce(float val, size_t reduce_dim) {
    const float lane_total = simd_sum(val);
    return lane_total / static_cast<float>(reduce_dim);
  }
};

template <typename U>
struct Sum {
  // Partial and final sums are accumulated in float (as MeanOfSquares does),
  // so half inputs neither lose precision nor overflow in the running partial
  // sums. The reduction kernel narrows the float result to its output type on store.
  // reduce_dim is unused; it is kept for a uniform Op interface.
  float simd_reduce(float val, size_t reduce_dim) {
    return simd_sum(val);
  }

  // Identity element of addition.
  static constexpr constant float init = 0.0f;

  // Adds one element (promoted to float) to the running sum.
  float operator()(float acc, U a) {
    return acc + static_cast<float>(a);
  }
};

template <typename U>
struct Min {
  // Identity for min: +inf is replaced by any finite value.
  static constexpr constant U init = metal::numeric_limits<U>::infinity();

  // Keeps the smaller operand; on a tie or unordered (NaN) comparison, b wins.
  U operator()(U a, U b) {
    return b > a ? a : b;
  }

  // Cross-lane minimum. reduce_dim is unused (kept for a uniform Op interface).
  template <typename T>
  T simd_reduce(T val, size_t reduce_dim) {
    return simd_min(val);
  }
};

template <typename U>
struct Max {
  // Identity for max: -inf is replaced by any finite value.
  static constexpr constant U init = -metal::numeric_limits<U>::infinity();

  // Keeps the larger operand; on a tie or unordered (NaN) comparison, b wins.
  U operator()(U a, U b) {
    return b < a ? a : b;
  }

  // Cross-lane maximum. reduce_dim is unused (kept for a uniform Op interface).
  template <typename T>
  T simd_reduce(T val, size_t reduce_dim) {
    return simd_max(val);
  }
};



template <typename U>
struct Prod {
  // Identity element of multiplication.
  static constexpr constant U init = U(1);

  // Multiplies one element into the running product (accumulated in U).
  U operator()(U acc, U a) {
    return a * acc;
  }

  // Cross-lane product. reduce_dim is unused (kept for a uniform Op interface).
  U simd_reduce(U val, size_t reduce_dim) {
    return simd_product(val);
  }
};


// Reduces axis 1 of a 3-D strided view with the reduction functor Op.
// Each threadgroup handles the (tgpig.x -> axis 2, tgpig.z -> axis 0) row; lanes
// of the simdgroup stride over axis 1 with step tpsg, then combine their partials
// with Op::simd_reduce. Lane 0 writes the result.
// tgpig.y only contributes to the output offset (presumably 0, since axis 1 is
// reduced — confirm on host).
// NOTE(review): only simdgroup-level reduction is used, so this presumably runs
// with one simdgroup per threadgroup; with more, each simdgroup redundantly
// computes and stores the same value.
template<typename F, typename Op>
[[kernel]] void reduce_nd3(
                device const void *input_b,
                device void *output_b,
                constant const size_t input_shape[3],
                constant const size_t input_strides[3],
                constant const size_t output_strides[3],
                uint3  tgpig[[threadgroup_position_in_grid]],
                uint  tiisg[[thread_index_in_simdgroup]],
                uint  tpsg[[threads_per_simdgroup]]
                ) {
    device const F *src = (device const F *)input_b;
    device F *dst = (device F *)output_b;

    // Length and stride of the axis being reduced.
    const size_t len = input_shape[1];
    const size_t step = input_strides[1];

    // Offset of the row start (axis-1 index 0) and of its output slot.
    const size_t src_row = tgpig.z * input_strides[0] + tgpig.x * input_strides[2];
    const size_t dst_pos = tgpig.z * output_strides[0]
            + tgpig.y * output_strides[1]
            + tgpig.x * output_strides[2];

    Op op = Op();

    // Per-lane partial reduction over a strided slice of the row.
    auto lane_acc = Op::init;
    for (size_t k = tiisg; k < len; k += tpsg) {
        const F value = src[src_row + k * step];
        lane_acc = op(lane_acc, value);
    }

    // Combine lane partials; every lane must execute this before the early-out.
    const auto total = op.simd_reduce(lane_acc, len);

    if (tiisg == 0) {
        dst[dst_pos] = total;
    }
}

// Kernel function type shared by every reduce_nd3 instantiation (the signature
// does not depend on F/Op, so any instantiation can be used to name it).
typedef decltype(reduce_nd3<float, Prod<float>>) reduce_nd3_t;

// Explicitly instantiates reduce_nd3<type, op<type>> and exports it to the host
// under the function name "nn_ops::reduce_<name>_nd3_<tname>".
#define INSTANTIATE_REDUCE(name, op, tname, type)                    \
template [[host_name("nn_ops::reduce_" #name "_nd3_" #tname)]]       \
[[kernel]] reduce_nd3_t reduce_nd3<type, op<type>>;


// f32 / f16 variants of each supported reduction.
INSTANTIATE_REDUCE(mean_of_squares, MeanOfSquares, f32, float)
INSTANTIATE_REDUCE(mean_of_squares, MeanOfSquares, f16, half)
INSTANTIATE_REDUCE(sum, Sum, f32, float)
INSTANTIATE_REDUCE(sum, Sum, f16, half)
INSTANTIATE_REDUCE(min, Min, f32, float)
INSTANTIATE_REDUCE(min, Min, f16, half)
INSTANTIATE_REDUCE(max, Max, f32, float)
INSTANTIATE_REDUCE(max, Max, f16, half)
INSTANTIATE_REDUCE(prod, Prod, f32, float)
INSTANTIATE_REDUCE(prod, Prod, f16, half)


// RMS-normalises each row of a 3-D strided tensor along axis 1:
//   y = x * rsqrt(mean(x^2) + eps)
// One threadgroup per (axis 0, axis 2) position: axis 2 = tgpig % shape[2],
// axis 0 = tgpig / shape[2]. Threads stride over axis 1 with step ntg, and
// squares are accumulated in float. eps is read from its buffer as a float
// regardless of F.
// Output is written at the same offsets as input, so output is assumed to
// share input's strides.
// NOTE(review): the second reduction reads shmem_f32[tiisg] on every lane, so the
// threadgroup buffer presumably holds at least one float per simd lane (32) and
// the threadgroup has at most that many simdgroups — confirm on the host side.
template<typename F>  
[[kernel]] void rms_norm_nd3(
                device const void *input_b,
                constant void * eps_b,
                device void *output_b,
                constant const size_t shape[3], 
                constant const size_t strides[3],
                threadgroup float * shmem_f32 [[threadgroup(0)]],
                uint   tgpig[[threadgroup_position_in_grid]],
                ushort tpitg[[thread_position_in_threadgroup]],
                ushort sgitg[[simdgroup_index_in_threadgroup]],
                ushort tiisg[[thread_index_in_simdgroup]],
                ushort   ntg[[threads_per_threadgroup]]
                ) {
    // Zero every per-simdgroup slot so slots of absent simdgroups add 0 below.
    if (sgitg == 0) {
        shmem_f32[tiisg] = 0.0f;
    }
    device const F* input = (device const F*) input_b;
    float eps = ((constant float *)eps_b)[0];
    device F * output = (device F*) output_b;

    size_t dim = shape[1];

    // Offset of element 0 of this row (axis 2 = tgpig % shape[2], axis 0 = tgpig / shape[2]).
    size_t base_idx = (tgpig % shape[2]) * strides[2] + (tgpig / shape[2]) * strides[0];

    // Per-thread partial sum of squares over a strided slice of the row.
    float partial_acc = 0.0;
    for (size_t i = tpitg; i < dim; i += ntg) {
        float el = static_cast<float>(input[base_idx + i * strides[1]]);
        partial_acc += el * el;
    }

    // Reduce within each simdgroup; the barrier orders the zero-fill above
    // before the slots are overwritten.
    partial_acc = simd_sum(partial_acc);
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Lane 0 of each simdgroup publishes its simdgroup's partial sum.
    if (tiisg == 0) {
        shmem_f32[sgitg] = partial_acc;
    }

    // Make every published partial visible before any simdgroup reads them.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Every simdgroup sums all slots, so each thread ends up with the row total.
    partial_acc = shmem_f32[tiisg];
    partial_acc = simd_sum(partial_acc);

    float mean_of_squares = partial_acc / dim;

    float norm = metal::rsqrt(mean_of_squares + eps);

    // Scale the row; output uses the input offsets (same layout assumed).
    for (size_t i = tpitg; i < dim; i += ntg) {
        auto idx = base_idx + i * strides[1];
        output[idx] = input[idx] * norm;
    }
}

// RMS-normalises rows of a 2-D tensor using 4-wide vector (F4) loads and stores:
//   y = x * 1 / sqrt(mean(x^2) + eps)
// One threadgroup per row (tgpig). Threads stride over the row's n_div_4 vectors
// with step ntg, and accumulation is done in float. eps is read as a float.
// NOTE(review): presumably n == 4 * n_div_4 (row length divisible by 4) — confirm.
// NOTE(review): the input row offset is tgpig * outer_stride BYTES (input_b is a
// char pointer), but the output row offset is tgpig * n_div_4 F4 elements, i.e.
// the output is assumed to be contiguous rows — confirm this asymmetry is intended.
// NOTE(review): shmem_f32 presumably holds at least one float per simd lane (32),
// with at most that many simdgroups per threadgroup.
template<typename F, typename F4>  
[[kernel]] void rms_norm_nd2_l4(
        device const char *input_b,
        constant char * eps_b,
        device char *output_b,
        constant const size_t & n,
        constant const size_t & n_div_4, 
        constant const size_t & outer_stride,
        threadgroup float * shmem_f32 [[threadgroup(0)]],
        uint   tgpig[[threadgroup_position_in_grid]],
        ushort tpitg[[thread_position_in_threadgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort   ntg[[threads_per_threadgroup]]) {
    // Zero every per-simdgroup slot so slots of absent simdgroups add 0 below.
    if (sgitg == 0) {
        shmem_f32[tiisg] = 0.0f;
    }

    // Row start: byte offset tgpig * outer_stride into the input.
    device const F4 * x = (device const F4 *) (input_b + tgpig*outer_stride);
    float eps = ((constant float *)eps_b)[0];
    float sumf = 0.0f;

    // parallel sum
    // Each thread accumulates dot(v, v) over its strided subset of vectors.
    for (size_t i = tpitg; i < n_div_4; i += ntg) {
        float4 el = static_cast<float4>(x[i]);
        sumf += dot(el, el);
    }
    sumf = simd_sum(sumf);

    // Order the zero-fill above before the slots are overwritten.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Lane 0 of each simdgroup publishes its simdgroup's partial sum.
    if (tiisg == 0) {
        shmem_f32[sgitg] = sumf;
    }

    // Make every published partial visible before any simdgroup reads them.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Every simdgroup sums all slots, so each thread ends up with the row total.
    sumf = shmem_f32[tiisg];
    sumf = simd_sum(sumf);

    const float mean  = sumf/n;
    const float scale = 1.0f/sqrt(mean + eps);

    // Output row start: tgpig * n_div_4 vectors (contiguous output rows assumed).
    device F4 * y = (device F4 *) output_b + tgpig * n_div_4;
    for (size_t i = tpitg; i < n_div_4; i += ntg) {
        y[i] = x[i] * scale;
    }
}

// Kernel function types used to declare the explicit instantiations below.
typedef decltype(rms_norm_nd3<float>) rms_norm_nd3_t;
typedef decltype(rms_norm_nd2_l4<float, float4>) rms_norm_nd2_l4_t;

// Host-visible f32 / f16 entry points for both RMS-norm variants.
template [[host_name("nn_ops::rms_norm_nd3_f32")]] [[kernel]] rms_norm_nd3_t rms_norm_nd3<float>;
template [[host_name("nn_ops::rms_norm_nd3_f16")]] [[kernel]] rms_norm_nd3_t rms_norm_nd3<half>;
template [[host_name("nn_ops::rms_norm_nd2_l4_f32")]] [[kernel]] rms_norm_nd2_l4_t rms_norm_nd2_l4<float, float4>;
template [[host_name("nn_ops::rms_norm_nd2_l4_f16")]] [[kernel]] rms_norm_nd2_l4_t rms_norm_nd2_l4<half, half4>;

// Logistic function 1 / (1 + e^-x), evaluated on |x| so the exponent argument
// is never positive (exp result <= 1), then reflected for negative inputs using
// sigmoid(-t) = 1 - sigmoid(t).
struct Sigmoid {
  template <typename T>
  T operator()(T x) {
    auto pos = 1 / (1 + metal::exp(-metal::abs(x)));
    if (x < 0) {
      return 1 - pos;
    }
    return pos;
  }
};

// Element-wise SiLU, y = x * sigmoid(x), one element per thread (tpig).
// The sigmoid is evaluated in float; the product is stored back as T.
template<typename T>
[[kernel]] void silu(device const void *input_b [[buffer(0)]],
                             device void *output_b [[buffer(1)]],
                             uint tpig[[thread_position_in_grid]]) {
   device const T *src = (device const T *)input_b;
   device T *dst = (device T *)output_b;

   const T x = src[tpig];
   dst[tpig] = Sigmoid()(static_cast<float>(x)) * x;
}

typedef decltype(silu<float>) silu_t;

// Vectorised SiLU over 4-wide elements (T4), one vector per thread:
//   y = x * sigmoid(x) = x / (1 + e^-x)
template<typename T4>
[[kernel]] void silu_4(
        device const void * input_b,
        device       void * output_b,
        uint tpig[[thread_position_in_grid]]) {
    const auto v = ((device const T4 *) input_b)[tpig];
    ((device T4 *) output_b)[tpig] = v / (1.0f + exp(-v));
}

// Kernel function type used to declare the vectorised SiLU instantiations.
typedef decltype(silu_4<float4>) silu_4_t;

// Host-visible scalar SiLU entry points.
template [[host_name("nn_ops::silu_f32")]] [[kernel]] silu_t silu<float>;
template [[host_name("nn_ops::silu_f16")]] [[kernel]] silu_t silu<half>;

// Host-visible 4-wide SiLU entry points.
template [[host_name("nn_ops::silu_4_f32")]] [[kernel]] silu_4_t silu_4<float4>;
template [[host_name("nn_ops::silu_4_f16")]] [[kernel]] silu_4_t silu_4<half4>;

// Softmax along axis 1 of a 3-D strided tensor. One simdgroup per
// (tgpig.x -> axis 2, tgpig.z -> axis 0) row; lanes stride over axis 1 with
// step tpsg. Math is done in float. The unnormalised exponentials are
// staged in output (as F) and rescaled in a final pass; output is addressed
// with the input strides, so it is assumed to share input's layout.
template<typename F>
[[kernel]] void softmax_nd3(
                device const void *input_b,
                device void *output_b,
                constant const size_t shape[3],
                constant const size_t strides[3],
                uint3  tgpig[[threadgroup_position_in_grid]],
                uint  tiisg[[thread_index_in_simdgroup]],
                uint  tpsg[[threads_per_simdgroup]]
                ) {
    device const F *src = (device const F *)input_b;
    device F *dst = (device F *)output_b;

    const size_t len = shape[1];
    const size_t step = strides[1];
    const size_t row = tgpig.x * strides[2] + tgpig.z * strides[0];

    // Pass 1: row maximum, subtracted below for numerical stability.
    float lane_max = -INFINITY;
    for (size_t k = tiisg; k < len; k += tpsg) {
        lane_max = max(lane_max, static_cast<float>(src[row + k * step]));
    }
    const float row_max = simd_max(lane_max);

    // Pass 2: exp(x - max), staged in dst, and its row sum.
    float lane_sum = 0;
    for (size_t k = tiisg; k < len; k += tpsg) {
        const size_t at = row + k * step;
        const float e = fast::exp(static_cast<float>(src[at]) - row_max);
        lane_sum += e;
        dst[at] = static_cast<F>(e);
    }
    const float inv_sum = 1.0 / simd_sum(lane_sum);

    // Pass 3: normalise the staged exponentials in place.
    for (size_t k = tiisg; k < len; k += tpsg) {
        const size_t at = row + k * step;
        dst[at] = static_cast<F>(static_cast<float>(dst[at]) * inv_sum);
    }
}

// Kernel function type used to declare the softmax instantiations.
typedef decltype(softmax_nd3<float>) softmax_nd3_t;

// Host-visible f32 / f16 softmax entry points.
template [[host_name("nn_ops::softmax_nd3_f32")]] [[kernel]] softmax_nd3_t softmax_nd3<float>;
template [[host_name("nn_ops::softmax_nd3_f16")]] [[kernel]] softmax_nd3_t softmax_nd3<half>;

// Fused scale + additive mask + softmax along axis 2 of a 3-D strided tensor:
//   logits = input * scale + mask;  output = softmax(logits)
// One simdgroup per (axis 1 = tgpig.y, axis 0 = tgpig.z) row; lanes stride over
// axis 2 with step tpsg. The logits are first stored into output (in F
// precision) and re-read by the later passes; exponentials are recomputed in
// the final pass rather than staged.
// NOTE(review): if every logit in a row is -inf (fully masked row), axis_max is
// -inf and exp(-inf - (-inf)) is NaN, so the whole row becomes NaN — confirm
// callers never pass fully masked rows, or that NaN is the desired result.
// NOTE(review): scale is read from its buffer as F (half for the f16 variant);
// confirm the host writes it with that element type.
template<typename F>  
[[kernel]] void scaled_masked_softmax_nd3(
                device const void *input_b,
                device const void *mask_b,
                constant void *scale_b,
                device void *output_b,
                constant const size_t shape[3], 
                constant const size_t strides[3],
                constant const size_t mask_strides[3],
                constant const size_t out_strides[3],
                uint3  tgpig[[threadgroup_position_in_grid]],
                uint  tiisg[[thread_index_in_simdgroup]],
                uint  tpsg[[threads_per_simdgroup]]
                ) {

    device const F *input = (device const F *)input_b;
    device const F *mask = (device const F *)mask_b;
    F scale = ((constant F *)scale_b)[0];
    device F *output = (device F *)output_b;

    size_t reduce_dim = shape[2];

    // Row starts (axis-2 index 0) in the input, mask and output views.
    size_t base_idx = tgpig.y * strides[1] 
            + tgpig.z * strides[0];

    size_t mask_base_idx = tgpig.y * mask_strides[1] 
            + tgpig.z * mask_strides[0];

    size_t base_out_idx = tgpig.y * out_strides[1] 
            + tgpig.z * out_strides[0];
    // Get max value on softmax reduce_dim after applying scale and mask
    // (the scaled + masked logits are computed in F and stored into output).
    float partial_max = -INFINITY;
    for (size_t i = tiisg; i < reduce_dim; i += tpsg) {
        auto idx = base_idx + i * strides[2];
        auto out_idx = base_out_idx + i * out_strides[2];
        auto mask_idx = mask_base_idx + i * mask_strides[2];
        output[out_idx] = input[idx] * scale + mask[mask_idx];
        float el = static_cast<float>(output[out_idx]);
        partial_max = max(partial_max, el);
    }

   float axis_max = simd_max(partial_max);

   // Compute Sum(exp(x - max))
   float partial_norm = 0;
   for (size_t i = tiisg; i < reduce_dim; i += tpsg) {
       auto out_idx = base_out_idx + i * out_strides[2];
       float el = static_cast<float>(output[out_idx]);
       float exp_el = fast::exp(el - axis_max);
       partial_norm += exp_el;
   }

   float axis_norm = simd_sum(partial_norm);
   float inv_axis_norm = 1.0 / axis_norm;

   // Recompute exp(logit - max) and overwrite each logit with its probability.
   for (size_t i = tiisg; i < reduce_dim; i += tpsg) {
       auto out_idx = base_out_idx + i * out_strides[2];
       float el = static_cast<float>(output[out_idx]);
       float exp_el = fast::exp(el - axis_max);
       output[out_idx] = static_cast<F>(exp_el * inv_axis_norm);
   }
}

// Kernel function type used to declare the scaled/masked softmax instantiations.
typedef decltype(scaled_masked_softmax_nd3<float>) scaled_masked_softmax_nd3_t;

// Host-visible f32 / f16 scaled/masked softmax entry points.
template [[host_name("nn_ops::scaled_masked_softmax_nd3_f32")]] [[kernel]] scaled_masked_softmax_nd3_t scaled_masked_softmax_nd3<float>;
template [[host_name("nn_ops::scaled_masked_softmax_nd3_f16")]] [[kernel]] scaled_masked_softmax_nd3_t scaled_masked_softmax_nd3<half>;

// Coefficients of the tanh approximation of GELU used in this file:
//   gelu(x) ~= 0.5 * x * (1 + tanh(SQRT_2_OVER_PI * (x + GELU_COEF_A * x^3)))
constant float GELU_COEF_A     = 0.044715f;
// sqrt(2 / pi)
constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;

// Element-wise tanh-approximation GELU, one element per thread (tpig):
//   y = 0.5 * x * (1 + tanh(SQRT_2_OVER_PI * (x + GELU_COEF_A * x^3)))
// Math is done in float and narrowed to F on store.
template<typename F>  
[[kernel]] void gelu_approx(
                device const void *input_b,
                device void *output_b,
                uint tpig[[thread_position_in_grid]]
                ) {

    device const F *input = (device const F *)input_b;
    device F *output = (device F *)output_b;

    float x = static_cast<float>(input[tpig]);
    // Cube by multiplication: metal::powr(x, y) is defined only for x >= 0 and
    // yields NaN for negative x, which would poison every negative input.
    float x_cubed = x * x * x;
    float output_f32 = 0.5f * x * (
      1.0f + precise::tanh(SQRT_2_OVER_PI
          * (x + GELU_COEF_A * x_cubed)));
    output[tpig] = static_cast<F>(output_f32);
}

// Kernel function type used to declare the GELU instantiations.
typedef decltype(gelu_approx<float>) gelu_approx_t;

// Host-visible f32 / f16 GELU (tanh approximation) entry points.
template [[host_name("nn_ops::gelu_approx_f32")]] [[kernel]] gelu_approx_t gelu_approx<float>;
template [[host_name("nn_ops::gelu_approx_f16")]] [[kernel]] gelu_approx_t gelu_approx<half>;

// Element-wise "fast" tanh-approximation GELU, one element per thread (tpig):
//   y = 0.5 * x * (1 + tanh(SQRT_2_OVER_PI * (x + GELU_COEF_A * x^2)))
// NOTE(review): the standard approximation (gelu_approx) uses x^3; the x^2 term
// here appears deliberate — confirm it matches the reference "fast" formula.
// Math is done in float and narrowed to F on store.
template<typename F>  
[[kernel]] void gelu_approx_fast(
                device const void *input_b,
                device void *output_b,
                uint tpig[[thread_position_in_grid]]
                ) {

    device const F *input = (device const F *)input_b;
    device F *output = (device F *)output_b;

    float x = static_cast<float>(input[tpig]);
    // Square by multiplication: metal::powr(x, y) is defined only for x >= 0 and
    // yields NaN for negative x, which would poison every negative input.
    float x_squared = x * x;
    float output_f32 = 0.5f * x * (
      1.0f + precise::tanh(SQRT_2_OVER_PI
          * (x + GELU_COEF_A * x_squared)));
    output[tpig] = static_cast<F>(output_f32);
}

// Kernel function type used to declare the fast GELU instantiations.
typedef decltype(gelu_approx_fast<float>) gelu_approx_fast_t;

// Host-visible f32 / f16 fast GELU entry points.
template [[host_name("nn_ops::gelu_approx_fast_f32")]] [[kernel]] gelu_approx_fast_t gelu_approx_fast<float>;
template [[host_name("nn_ops::gelu_approx_fast_f16")]] [[kernel]] gelu_approx_fast_t gelu_approx_fast<half>;



// Applies rotary position embedding to a 2-D strided tensor. Each thread
// handles the pair (i, i + h) along the last axis, with i = tpig.x,
// h = shape[1] / 2, and tpig.y selecting the row:
//   out[i]     = x[i]     * cos[i]     - x[i + h] * sin[i]
//   out[i + h] = x[i + h] * cos[i + h] + x[i]     * sin[i + h]
// NOTE(review): presumably launched with the grid's x extent = shape[1] / 2;
// there is no bounds check.
template<typename T>  
[[kernel]] void apply_rope_nd2(             
      device const void *input_b [[buffer(0)]],
      device const void *cos_b [[buffer(1)]],
      device const void *sin_b [[buffer(2)]],                 
      device void *output_b [[buffer(3)]],                        
      constant const size_t * shape [[buffer(4)]],
      constant const size_t * strides [[buffer(5)]],
      constant const size_t * cos_sin_strides [[buffer(6)]],
      constant const size_t * out_strides [[buffer(7)]],
      uint2 tpig[[thread_position_in_grid]]
) {
  device const T *input = (device const T *)input_b;
  device const T *cos = (device const T *)cos_b;
  device const T *sin = (device const T *)sin_b;

  device T* output = (device T *) output_b;

  // Position of the partner element, half a row further along the last axis.
  uint2 rotated_tpig = tpig;
  rotated_tpig.x += shape[1] / 2;

  auto idx = indices_to_idx_2(tpig, strides);
  auto rot_idx = indices_to_idx_2(rotated_tpig, strides);
  auto out_idx = indices_to_idx_2(tpig, out_strides);
  auto out_rot_idx = indices_to_idx_2(rotated_tpig, out_strides);

  auto cos_sin_idx = indices_to_idx_2(tpig, cos_sin_strides);
  auto rot_cos_sin_idx = indices_to_idx_2(rotated_tpig, cos_sin_strides);

  // Load both halves of the pair before storing either, so the second rotation
  // does not read an already-rotated value when output_b aliases input_b.
  const T x_lo = input[idx];
  const T x_hi = input[rot_idx];

  output[out_idx] = x_lo * cos[cos_sin_idx] - x_hi * sin[cos_sin_idx];
  output[out_rot_idx] = x_hi * cos[rot_cos_sin_idx]
          + x_lo * sin[rot_cos_sin_idx];
}

// Applies rotary position embedding to a 3-D strided tensor. Each thread
// handles the pair (i, i + h) along the last axis, with i = tpig.x,
// h = shape[2] / 2; tpig.y and tpig.z select axes 1 and 0:
//   out[i]     = x[i]     * cos[i]     - x[i + h] * sin[i]
//   out[i + h] = x[i + h] * cos[i + h] + x[i]     * sin[i + h]
// NOTE(review): presumably launched with the grid's x extent = shape[2] / 2;
// there is no bounds check.
template<typename T>  
[[kernel]] void apply_rope_nd3(             
      device const void *input_b [[buffer(0)]],
      device const void *cos_b [[buffer(1)]],
      device const void *sin_b [[buffer(2)]],                 
      device void *output_b [[buffer(3)]],                        
      constant const size_t * shape [[buffer(4)]],
      constant const size_t * strides [[buffer(5)]],
      constant const size_t * cos_sin_strides [[buffer(6)]],
      constant const size_t * out_strides [[buffer(7)]],
      uint3 tpig[[thread_position_in_grid]]
) {
  device const T *input = (device const T *)input_b;
  device const T *cos = (device const T *)cos_b;
  device const T *sin = (device const T *)sin_b;

  device T* output = (device T *) output_b;

  // Position of the partner element, half a row further along the last axis.
  uint3 rotated_tpig = tpig;
  rotated_tpig.x += shape[2] / 2;

  auto idx = indices_to_idx_3(tpig, strides);
  auto rot_idx = indices_to_idx_3(rotated_tpig, strides);
  auto out_idx = indices_to_idx_3(tpig, out_strides);
  auto out_rot_idx = indices_to_idx_3(rotated_tpig, out_strides);

  auto cos_sin_idx = indices_to_idx_3(tpig, cos_sin_strides);
  auto rot_cos_sin_idx = indices_to_idx_3(rotated_tpig, cos_sin_strides);

  // Load both halves of the pair before storing either, so the second rotation
  // does not read an already-rotated value when output_b aliases input_b.
  const T x_lo = input[idx];
  const T x_hi = input[rot_idx];

  output[out_idx] = x_lo * cos[cos_sin_idx] - x_hi * sin[cos_sin_idx];
  output[out_rot_idx] = x_hi * cos[rot_cos_sin_idx]
          + x_lo * sin[rot_cos_sin_idx];
}

// Applies rotary position embedding to a 4-D strided tensor. Each thread
// handles the pair (i, i + h) along the last axis, with i = tpig.x,
// h = shape[3] / 2; tpig.y selects axis 2 and tpig.z packs axes 0 and 1
// (decoded by indices_to_idx_4 as z / shape[1], z % shape[1]):
//   out[i]     = x[i]     * cos[i]     - x[i + h] * sin[i]
//   out[i + h] = x[i + h] * cos[i + h] + x[i]     * sin[i + h]
// NOTE(review): presumably launched with the grid's x extent = shape[3] / 2;
// there is no bounds check.
template<typename T>  
[[kernel]] void apply_rope_nd4(             
      device const void *input_b [[buffer(0)]],
      device const void *cos_b [[buffer(1)]],
      device const void *sin_b [[buffer(2)]],                 
      device void *output_b [[buffer(3)]],                        
      constant const size_t * shape [[buffer(4)]],
      constant const size_t * strides [[buffer(5)]],
      constant const size_t * cos_sin_strides [[buffer(6)]],
      constant const size_t * out_strides [[buffer(7)]],
      uint3 tpig[[thread_position_in_grid]]
) {
  device const T *input = (device const T *)input_b;
  device const T *cos = (device const T *)cos_b;
  device const T *sin = (device const T *)sin_b;

  device T* output = (device T *) output_b;

  // Position of the partner element, half a row further along the last axis.
  uint3 rotated_tpig = tpig;
  rotated_tpig.x += shape[3] / 2;

  auto idx = indices_to_idx_4(tpig, shape, strides);
  auto rot_idx = indices_to_idx_4(rotated_tpig, shape, strides);
  auto out_idx = indices_to_idx_4(tpig, shape, out_strides);
  auto out_rot_idx = indices_to_idx_4(rotated_tpig, shape, out_strides);

  auto cos_sin_idx = indices_to_idx_4(tpig, shape, cos_sin_strides);
  auto rot_cos_sin_idx = indices_to_idx_4(rotated_tpig, shape, cos_sin_strides);

  // Load both halves of the pair before storing either, so the second rotation
  // does not read an already-rotated value when output_b aliases input_b.
  const T x_lo = input[idx];
  const T x_hi = input[rot_idx];

  output[out_idx] = x_lo * cos[cos_sin_idx] - x_hi * sin[cos_sin_idx];
  output[out_rot_idx] = x_hi * cos[rot_cos_sin_idx]
          + x_lo * sin[rot_cos_sin_idx];
}


// Kernel function types used to declare the RoPE instantiations.
typedef decltype(apply_rope_nd2<float>) apply_rope_nd2_t;
typedef decltype(apply_rope_nd3<float>) apply_rope_nd3_t;
typedef decltype(apply_rope_nd4<float>) apply_rope_nd4_t;

// Host-visible f32 RoPE entry points for 2-D, 3-D and 4-D tensors.
template [[host_name("nn_ops::apply_rope_nd2_f32")]] [[kernel]] apply_rope_nd2_t apply_rope_nd2<float>;
template [[host_name("nn_ops::apply_rope_nd3_f32")]] [[kernel]] apply_rope_nd3_t apply_rope_nd3<float>;
template [[host_name("nn_ops::apply_rope_nd4_f32")]] [[kernel]] apply_rope_nd4_t apply_rope_nd4<float>;

// Host-visible f16 RoPE entry points for 2-D, 3-D and 4-D tensors.
template [[host_name("nn_ops::apply_rope_nd2_f16")]] [[kernel]] apply_rope_nd2_t apply_rope_nd2<half>;
template [[host_name("nn_ops::apply_rope_nd3_f16")]] [[kernel]] apply_rope_nd3_t apply_rope_nd3<half>;
template [[host_name("nn_ops::apply_rope_nd4_f16")]] [[kernel]] apply_rope_nd4_t apply_rope_nd4<half>;


